{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will convert the TensorFlow weights of the DeepID model to Keras format.\n",
    "\n",
    "DeepID2 model: https://arxiv.org/pdf/1406.4773.pdf\n",
    "\n",
    "Pre-trained TensorFlow weights: https://github.com/Ruoyiran/DeepID/tree/master/model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import keras\n",
    "from keras.models import Model, Sequential\n",
    "from keras.layers import Conv2D, Activation, Input, Add\n",
    "from keras.layers.core import Dense, Flatten, Dropout\n",
    "from keras.layers.pooling import MaxPooling2D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build DeepID Keras model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_deepid_keras_model():\n",
    "    \"\"\"Build the DeepID network as a Keras functional model.\n",
    "\n",
    "    The network takes a 55x47 image with 3 channels, runs it through three\n",
    "    conv / max-pool stages, and then forms two 160-d projections: one from\n",
    "    the flattened third-stage feature map (fc11) and one from a fourth\n",
    "    convolution applied to that same feature map (fc12). The two projections\n",
    "    are summed and passed through a ReLU named 'deepid'.\n",
    "\n",
    "    The layer names Conv1..Conv4, fc11 and fc12 are used later in the\n",
    "    notebook to locate the converted .npy weight files, so keep them as is.\n",
    "\n",
    "    Returns:\n",
    "        A keras Model mapping the input image to the 160-d 'deepid' output.\n",
    "        No pre-trained weights are loaded here.\n",
    "    \"\"\"\n",
    "    myInput = Input(shape=(55, 47, 3))\n",
    "\n",
    "    # NOTE: Keras' Dropout `rate` is the fraction of units that is DROPPED,\n",
    "    # not the keep probability. A pass-through dropout therefore needs\n",
    "    # rate=0; rate=1 would discard every activation (or be rejected) as soon\n",
    "    # as the model is run in training mode. The D1-D3 layers are kept so the\n",
    "    # layer graph and layer names stay the same.\n",
    "    x = Conv2D(20, (4, 4), name='Conv1', activation='relu')(myInput)\n",
    "    x = MaxPooling2D(pool_size=2, strides=2, name='Pool1')(x)\n",
    "    x = Dropout(rate=0, name='D1')(x)\n",
    "\n",
    "    x = Conv2D(40, (3, 3), name='Conv2', activation='relu')(x)\n",
    "    x = MaxPooling2D(pool_size=2, strides=2, name='Pool2')(x)\n",
    "    x = Dropout(rate=0, name='D2')(x)\n",
    "\n",
    "    x = Conv2D(60, (3, 3), name='Conv3', activation='relu')(x)\n",
    "    x = MaxPooling2D(pool_size=2, strides=2, name='Pool3')(x)\n",
    "    x = Dropout(rate=0, name='D3')(x)\n",
    "\n",
    "    # Branch 1: project the third-stage feature map directly.\n",
    "    x1 = Flatten()(x)\n",
    "    fc11 = Dense(160, name='fc11')(x1)\n",
    "\n",
    "    # Branch 2: one more convolution on the same feature map, then project.\n",
    "    x2 = Conv2D(80, (2, 2), name='Conv4', activation='relu')(x)\n",
    "    x2 = Flatten()(x2)\n",
    "    fc12 = Dense(160, name='fc12')(x2)\n",
    "\n",
    "    # Merge both branches; the ReLU is applied after the sum, which is why\n",
    "    # the two Dense layers above have no activation of their own.\n",
    "    y = Add()([fc11, fc12])\n",
    "    y = Activation('relu', name='deepid')(y)\n",
    "\n",
    "    model = Model(inputs=[myInput], outputs=y)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Keras model. No pre-trained weights are loaded at this point;\n",
    "# that happens in the weight-loading loop further down in this notebook.\n",
    "deepid_keras_model = create_deepid_keras_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 55, 47, 3)    0                                            \n",
      "__________________________________________________________________________________________________\n",
      "Conv1 (Conv2D)                  (None, 52, 44, 20)   980         input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "Pool1 (MaxPooling2D)            (None, 26, 22, 20)   0           Conv1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "D1 (Dropout)                    (None, 26, 22, 20)   0           Pool1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "Conv2 (Conv2D)                  (None, 24, 20, 40)   7240        D1[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "Pool2 (MaxPooling2D)            (None, 12, 10, 40)   0           Conv2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "D2 (Dropout)                    (None, 12, 10, 40)   0           Pool2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "Conv3 (Conv2D)                  (None, 10, 8, 60)    21660       D2[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "Pool3 (MaxPooling2D)            (None, 5, 4, 60)     0           Conv3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "D3 (Dropout)                    (None, 5, 4, 60)     0           Pool3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "Conv4 (Conv2D)                  (None, 4, 3, 80)     19280       D3[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 1200)         0           D3[0][0]                         \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)             (None, 960)          0           Conv4[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "fc11 (Dense)                    (None, 160)          192160      flatten_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "fc12 (Dense)                    (None, 160)          153760      flatten_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "add_1 (Add)                     (None, 160)          0           fc11[0][0]                       \n",
      "                                                                 fc12[0][0]                       \n",
      "__________________________________________________________________________________________________\n",
      "deepid (Activation)             (None, 160)          0           add_1[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 395,080\n",
      "Trainable params: 395,080\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: list every layer with its output shape and parameter count.\n",
    "deepid_keras_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert TF weights to Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A handy guide to converting TensorFlow weights to Keras: https://github.com/nyoki-mtl/keras-facenet/blob/master/notebook/tf_to_keras.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the original TensorFlow checkpoint (best_model.ckpt is read from here).\n",
    "tf_model_dir = 'tf-model/'\n",
    "# Folder for the intermediate .npy files, one per checkpoint variable.\n",
    "npy_weights_dir = 'keras-model/npy_weights/'\n",
    "# Folder where the converted Keras weights file is written.\n",
    "weights_dir = 'keras-model/weights/'\n",
    "# Created below but not written to by the cells visible here;\n",
    "# presumably reserved for a fully saved model - TODO confirm.\n",
    "model_dir = 'keras-model/model/'\n",
    "\n",
    "# File name of the converted Keras weights, saved inside weights_dir.\n",
    "weights_filename = 'deepid_keras_weights.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure every output folder exists before anything is written to it.\n",
    "for output_dir in (npy_weights_dir, weights_dir, model_dir):\n",
    "    os.makedirs(output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "re_repeat = re.compile(r'Repeat_[0-9_]*b')\n",
    "re_block8 = re.compile(r'Block8_[A-Za-z]')\n",
    "\n",
    "\n",
    "def get_filename(key):\n",
    "    \"\"\"Map a TensorFlow variable name to the .npy file name used for it.\n",
    "\n",
    "    Slashes become underscores, every match of the re_repeat pattern is\n",
    "    collapsed to 'B', the TF substrings '_weights' / '_biases' are renamed\n",
    "    to the Keras names '_kernel' / '_bias', and '.npy' is appended.\n",
    "\n",
    "    Args:\n",
    "        key: variable name from the checkpoint (converted with str()).\n",
    "\n",
    "    Returns:\n",
    "        The file name (no directory part) as a str.\n",
    "    \"\"\"\n",
    "    flat_name = re_repeat.sub('B', str(key).replace('/', '_'))\n",
    "\n",
    "    # TF naming -> Keras naming, applied in this order.\n",
    "    for tf_part, keras_part in (('_weights', '_kernel'), ('_biases', '_bias')):\n",
    "        flat_name = flat_name.replace(tf_part, keras_part)\n",
    "\n",
    "    return '{}.npy'.format(flat_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_tensors_from_checkpoint_file(filename, output_folder):\n",
    "    \"\"\"Dump the variables of a TensorFlow checkpoint to individual .npy files.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the checkpoint to read.\n",
    "        output_folder: directory that receives the .npy files; each file is\n",
    "            named by passing the variable name through get_filename().\n",
    "\n",
    "    The variable 'global_step' and any variable whose name contains\n",
    "    'AuxLogit' are skipped.\n",
    "    \"\"\"\n",
    "    ckpt_reader = tf.train.NewCheckpointReader(filename)\n",
    "\n",
    "    for var_name in ckpt_reader.get_variable_to_shape_map():\n",
    "        # These tensors are deliberately not exported.\n",
    "        if var_name == 'global_step' or 'AuxLogit' in var_name:\n",
    "            continue\n",
    "\n",
    "        # One file per variable, named after the matching Keras weight.\n",
    "        target_path = os.path.join(output_folder, get_filename(var_name))\n",
    "        np.save(target_path, ckpt_reader.get_tensor(var_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
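    "The pre-trained TensorFlow weights are available in this repo: https://github.com/Ruoyiran/DeepID/tree/master/model\n",
    "\n",
    "The next cell reads the checkpoint from `tf-model/best_model.ckpt`, so place the downloaded files in `tf-model/` first."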
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dump every checkpoint variable to its own .npy file in npy_weights_dir.\n",
    "checkpoint_path = tf_model_dir + 'best_model.ckpt'\n",
    "extract_tensors_from_checkpoint_file(checkpoint_path, npy_weights_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading numpy weights from keras-model/npy_weights/\n",
      "Conv1_conv2d_kernel.npy\n",
      "Conv1_conv2d_bias.npy\n",
      "Conv2_conv2d_kernel.npy\n",
      "Conv2_conv2d_bias.npy\n",
      "Conv3_conv2d_kernel.npy\n",
      "Conv3_conv2d_bias.npy\n",
      "Conv4_conv2d_kernel.npy\n",
      "Conv4_conv2d_bias.npy\n",
      "DeepID_fc11_kernel.npy\n",
      "DeepID_fc11_bias.npy\n",
      "DeepID_fc12_kernel.npy\n",
      "DeepID_fc12_bias.npy\n"
     ]
    }
   ],
   "source": [
    "print('Loading numpy weights from', npy_weights_dir)\n",
    "\n",
    "for layer in deepid_keras_model.layers:\n",
    "    # Layers without parameters have nothing to load.\n",
    "    if not layer.weights:\n",
    "        continue\n",
    "\n",
    "    # Work out the file-name prefix once per layer. The extracted files are\n",
    "    # named '{layer}_conv2d_{weight}.npy' for convolutions and\n",
    "    # 'DeepID_{layer}_{weight}.npy' for the fully connected layers.\n",
    "    if isinstance(layer, keras.layers.Conv2D):\n",
    "        layer_name = layer.name + \"_conv2d\"\n",
    "    elif isinstance(layer, keras.layers.Dense):\n",
    "        layer_name = \"DeepID_\" + layer.name\n",
    "    else:\n",
    "        # Fail loudly: without this branch a weighted layer of another type\n",
    "        # would reuse the prefix of the layer before it (or hit a NameError)\n",
    "        # and could silently load the wrong arrays.\n",
    "        raise ValueError('No weight-file naming rule for layer %r (%s)'\n",
    "                         % (layer.name, type(layer).__name__))\n",
    "\n",
    "    weights = []\n",
    "    for w in layer.weights:\n",
    "        # Keep only the last path component of the variable name and drop\n",
    "        # the ':0' marker, e.g. a name ending in 'kernel:0' gives 'kernel'.\n",
    "        weight_name = os.path.basename(w.name).replace(':0', '')\n",
    "        weight_file = layer_name + '_' + weight_name + '.npy'\n",
    "        print(weight_file)\n",
    "\n",
    "        weight_arr = np.load(os.path.join(npy_weights_dir, weight_file))\n",
    "        weights.append(weight_arr)\n",
    "\n",
    "    # Arrays are passed in the same order as layer.weights lists them.\n",
    "    layer.set_weights(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted weights to weights_dir under weights_filename.\n",
    "keras_weights_path = os.path.join(weights_dir, weights_filename)\n",
    "deepid_keras_model.save_weights(keras_weights_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
